{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Stacking in PyTorch Tabular\n",
    "\n",
    "This notebook shows how to use the model stacking functionality in PyTorch Tabular to combine several models into a single model, with a shared head producing the final prediction.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Setup and Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import (\n",
    "    CategoryEmbeddingModelConfig,\n",
    "    FTTransformerConfig,\n",
    "    TabNetModelConfig,\n",
    ")\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.models.stacking import StackingModelConfig\n",
    "from pytorch_tabular.utils import make_mixed_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a synthetic classification dataset and split it into train, validation, and test sets"
   ]
  },
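  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a synthetic dataset that mixes continuous and categorical columns\n",
    "data, cat_col_names, num_col_names = make_mixed_dataset(\n",
    "    task=\"classification\", n_samples=3000, n_features=7, n_categories=4\n",
    ")\n",
    "\n",
    "# train_test_split holds out 25% by default: first carve out the test set,\n",
    "# then split the remainder into train and validation\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, valid = train_test_split(train, random_state=42)"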
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Common configurations"
   ]
  },
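  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shared by the stacking model and the individual models trained for comparison\n",
    "data_config = DataConfig(\n",
    "    target=[\"target\"],\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "# Stop early when validation accuracy stops improving, and reload the best checkpoint\n",
    "trainer_config = TrainerConfig(\n",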
    "    batch_size=1024,\n",
    "    max_epochs=20,\n",
    "    early_stopping=\"valid_accuracy\",\n",
    "    early_stopping_mode=\"max\",\n",
    "    early_stopping_patience=3,\n",
    "    checkpoints=\"valid_accuracy\",\n",
    "    load_best=True,\n",
    ")\n",
    "optimizer_config = OptimizerConfig()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configure individual models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config_1 = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"128-64-32\",\n",
    "    activation=\"ReLU\",\n",
    "    learning_rate=1e-3\n",
    ")\n",
    "model_config_2 = FTTransformerConfig(\n",
    "    task=\"classification\",\n",
    "    input_embed_dim=32,\n",
    "    num_attn_blocks=2,\n",
    "    num_heads=4,\n",
    "    learning_rate=1e-3\n",
    ")\n",
    "model_config_3 = TabNetModelConfig(\n",
    "    task=\"classification\",\n",
    "    n_d=8,\n",
    "    n_a=8,\n",
    "    n_steps=3,\n",
    "    learning_rate=1e-3\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Configure Stacking Model\n",
    "\n",
    "Now let's set up the stacking configuration that will combine these models:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "stacking_config = StackingModelConfig(\n",
    "    task=\"classification\",\n",
    "    # Base models to combine\n",
    "    model_configs=[\n",
    "        model_config_1,\n",
    "        model_config_2,\n",
    "        model_config_3\n",
    "    ],\n",
    "    # Head that produces the final prediction from the combined models\n",
    "    head=\"LinearHead\",\n",
    "    head_config={\n",
    "        \"layers\": \"64\",\n",
    "        \"activation\": \"ReLU\",\n",
    "        \"dropout\": 0.1\n",
    "    },\n",
    "    learning_rate=1e-3\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Train Stacking Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">338</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">147</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m338\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m147\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">388</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">549</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m388\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m549\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">394</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:527</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m394\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:527\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">462</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">600</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: StackingModel          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m462\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m600\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: StackingModel          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">516</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">343</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m516\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m343\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:35</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">813</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">679</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:35\u001b[0m,\u001b[1;36m813\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m679\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Mode  </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ StackingBackbone       │ 77.2 K │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ StackingEmbeddingLayer │    917 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ LinearHead             │ 12.5 K │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss       │      0 │ train │\n",
       "└───┴──────────────────┴────────────────────────┴────────┴───────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                  \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ StackingBackbone       │ 77.2 K │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ StackingEmbeddingLayer │    917 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ LinearHead             │ 12.5 K │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss       │      0 │ train │\n",
       "└───┴──────────────────┴────────────────────────┴────────┴───────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 90.6 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 90.6 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in train mode</span>: 188                                                                                         \n",
       "<span style=\"font-weight: bold\">Modules in eval mode</span>: 0                                                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 90.6 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 90.6 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n",
       "\u001b[1mModules in train mode\u001b[0m: 188                                                                                         \n",
       "\u001b[1mModules in eval mode\u001b[0m: 0                                                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:39</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">304</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">692</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:39\u001b[0m,\u001b[1;36m304\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m692\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:02:39</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">307</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1533</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:02:39\u001b[0m,\u001b[1;36m307\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1533\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<pytorch_lightning.trainer.trainer.Trainer at 0x7fb1a508d420>"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "stacking_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=stacking_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    ")\n",
    "stacking_model.fit(\n",
    "    train=train,\n",
    "    validation=valid\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.5960000157356262     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7419928312301636     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7419928312301636     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.5960000157356262    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7419928312301636    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7419928312301636    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
     }
    ],
    "source": [
     "predictions = stacking_model.predict(test)\n",
     "# evaluate returns a list of metric dicts; take the first (DataLoader 0)\n",
     "stacking_metrics = stacking_model.evaluate(test)[0]\n",
     "stacking_acc = stacking_metrics[\"test_accuracy\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare with individual models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate_model(model_config, name):\n",
    "    # Train a single model with the same data splits and configs as the stacking model\n",
    "    model = TabularModel(\n",
    "        data_config=data_config,\n",
    "        model_config=model_config,\n",
    "        optimizer_config=optimizer_config,\n",
    "        trainer_config=trainer_config,\n",
    "    )\n",
    "    model.fit(train=train, validation=valid)\n",
    "    # Evaluate on the held-out test set and report the metrics\n",
    "    metrics = model.evaluate(test)\n",
    "    print(f\"\\n{name} Metrics:\")\n",
    "    print(metrics)\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">257</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">147</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m257\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m147\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">320</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">549</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m320\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m549\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">340</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:527</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m340\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:527\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">376</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">600</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: CategoryEmbeddingModel \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m376\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m600\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">411</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">343</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m411\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m343\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:01</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">638</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">679</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:01\u001b[0m,\u001b[1;36m638\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m679\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Mode  </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │ 12.1 K │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │     53 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │     66 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │ train │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┴───────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │ 12.1 K │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     53 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │     66 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │ train │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┴───────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 12.2 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 12.2 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in train mode</span>: 19                                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in eval mode</span>: 0                                                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 12.2 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 12.2 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n",
       "\u001b[1mModules in train mode\u001b[0m: 19                                                                                          \n",
       "\u001b[1mModules in eval mode\u001b[0m: 0                                                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03ed36b48da24bb19f036d1db4422cb7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=20` reached.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:04</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">935</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">692</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:04\u001b[0m,\u001b[1;36m935\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m692\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:04</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">938</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1533</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:04\u001b[0m,\u001b[1;36m938\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1533\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdcb7befb3b340a895a5399394780d7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.4586666524410248     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8828091025352478     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8828091025352478     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.4586666524410248    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8828091025352478    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8828091025352478    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Category Embedding Metrics:\n",
      "[{'test_loss_0': 0.8828091025352478, 'test_loss': 0.8828091025352478, 'test_accuracy': 0.4586666524410248}]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">183</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">147</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m183\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m147\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">263</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">549</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m263\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m549\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">272</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:527</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m272\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:527\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">294</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">600</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: FTTransformerModel     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m294\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m600\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: FTTransformerModel     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">323</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">343</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m323\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m343\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:05</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">623</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">679</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:05\u001b[0m,\u001b[1;36m623\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m679\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                  </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Mode  </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ FTTransformerBackbone │ 57.7 K │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding2dLayer      │    864 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ LinearHead            │     66 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss      │      0 │ train │\n",
       "└───┴──────────────────┴───────────────────────┴────────┴───────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                 \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ FTTransformerBackbone │ 57.7 K │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding2dLayer      │    864 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ LinearHead            │     66 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss      │      0 │ train │\n",
       "└───┴──────────────────┴───────────────────────┴────────┴───────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 58.6 K                                                                                           \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 58.6 K                                                                                               \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in train mode</span>: 56                                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in eval mode</span>: 0                                                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 58.6 K                                                                                           \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 58.6 K                                                                                               \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n",
       "\u001b[1mModules in train mode\u001b[0m: 56                                                                                          \n",
       "\u001b[1mModules in eval mode\u001b[0m: 0                                                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62184d0ac93049058c153f2e93518d0f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">482</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">692</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m482\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m692\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">488</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1533</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m488\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1533\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a482fb9cd5045e3ada1beac9c114d97",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.5546666383743286     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6846821904182434     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6846821904182434     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.5546666383743286    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6846821904182434    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6846821904182434    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "FT Transformer Metrics:\n",
      "[{'test_loss_0': 0.6846821904182434, 'test_loss': 0.6846821904182434, 'test_accuracy': 0.5546666383743286}]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">824</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">147</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m824\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m147\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">863</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">549</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m863\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m549\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">870</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:527</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m870\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:527\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">900</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">600</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: TabNetModel            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m900\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m600\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: TabNetModel            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">965</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">343</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:07\u001b[0m,\u001b[1;36m965\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m343\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:08</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">200</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">679</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:08\u001b[0m,\u001b[1;36m200\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m679\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Mode  </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _embedding_layer │ Identity         │      0 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _backbone        │ TabNetBackbone   │  6.4 K │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ Identity         │      0 │ train │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss │      0 │ train │\n",
       "└───┴──────────────────┴──────────────────┴────────┴───────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mMode \u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Identity         │      0 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ TabNetBackbone   │  6.4 K │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ Identity         │      0 │ train │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss │      0 │ train │\n",
       "└───┴──────────────────┴──────────────────┴────────┴───────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 6.4 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 6.4 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "<span style=\"font-weight: bold\">Modules in train mode</span>: 111                                                                                         \n",
       "<span style=\"font-weight: bold\">Modules in eval mode</span>: 0                                                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 6.4 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 6.4 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n",
       "\u001b[1mModules in train mode\u001b[0m: 111                                                                                         \n",
       "\u001b[1mModules in eval mode\u001b[0m: 0                                                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">766</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">692</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:09\u001b[0m,\u001b[1;36m766\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m692\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:09:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">767</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1533</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:09:09\u001b[0m,\u001b[1;36m767\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1533\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.4346666634082794     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.1570961475372314     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.1570961475372314     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.4346666634082794    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.1570961475372314    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.1570961475372314    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "TabNet Metrics:\n",
      "[{'test_loss_0': 1.1570961475372314, 'test_loss': 1.1570961475372314, 'test_accuracy': 0.4346666634082794}]\n"
     ]
    }
   ],
   "source": [
    "ce_metrics = train_and_evaluate_model(model_config_1, \"Category Embedding\")[0]\n",
    "ft_metrics = train_and_evaluate_model(model_config_2, \"FT Transformer\")[0]\n",
    "tab_metrics = train_and_evaluate_model(model_config_3, \"TabNet\")[0]\n",
    "ce_acc = ce_metrics[\"test_accuracy\"]\n",
    "ft_acc = ft_metrics[\"test_accuracy\"]\n",
    "tab_acc = tab_metrics[\"test_accuracy\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Stacking Model Test Accuracy: 0.5960000157356262\n",
      "Category Embedding Model Test Accuracy: 0.4586666524410248\n",
      "FT Transformer Model Test Accuracy: 0.5546666383743286\n",
      "TabNet Model Test Accuracy: 0.4346666634082794\n"
     ]
    }
   ],
   "source": [
    "print(\"Stacking Model Test Accuracy: {}\".format(stacking_acc))\n",
    "print(\"Category Embedding Model Test Accuracy: {}\".format(ce_acc))\n",
    "print(\"FT Transformer Model Test Accuracy: {}\".format(ft_acc))\n",
    "print(\"TabNet Model Test Accuracy: {}\".format(tab_acc))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save the stacking model & load it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:00:31</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">524</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1579</span><span style=\"font-weight: bold\">}</span> - WARNING - Directory is not empty. Overwriting the \n",
       "contents.                                                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:00:31\u001b[0m,\u001b[1;36m524\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1579\u001b[0m\u001b[1m}\u001b[0m - WARNING - Directory is not empty. Overwriting the \n",
       "contents.                                                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "stacking_model.save_model(\"stacking_model\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:00:32</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">437</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">172</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:00:32\u001b[0m,\u001b[1;36m437\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m172\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">00:00:32</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">452</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">343</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m12\u001b[0m \u001b[1;92m00:00:32\u001b[0m,\u001b[1;36m452\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m343\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "loaded_model = TabularModel.load_model(\"stacking_model\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "## Key Points About Stacking\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "1. The stacking model combines predictions from multiple base models into a final prediction\n",
    "2. Each base model can have its own architecture and hyperparameters\n",
    "3. The head layer combines the outputs from all base models\n",
    "4. Base models are trained simultaneously\n",
    "5. The stacking model can often achieve better performance than individual models\n",
    "\n",
    "## Tips for Better Stacking Results\n",
    "\n",
    "1. Use diverse base models that capture different aspects of the data\n",
    "2. Experiment with different head architectures\n",
    "3. Consider using cross-validation for more robust stacking\n",
    "4. Balance model complexity with training time\n",
    "5. Monitor individual model performances to ensure they contribute meaningfully\n",
    "\n",
    "This example demonstrates basic stacking functionality. For production use cases, you may want to:\n",
    "- Use cross-validation\n",
    "- Implement more sophisticated ensemble techniques\n",
    "- Add custom metrics\n",
    "- Tune hyperparameters for both base models and stacking head"
   ]
  }
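  ,
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Sketch: how a head can combine base model outputs\n",
    "\n",
    "The key points state that the head layer combines the outputs from all base models and that the base models are trained simultaneously. As a rough illustration of that idea, the sketch below uses plain PyTorch rather than PyTorch Tabular. `TinyStack` and its arguments are hypothetical names, and concatenating the backbone outputs before a single linear head is an assumption about one reasonable way to combine them, not a description of the library's internals.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "class TinyStack(nn.Module):\n",
    "    # Hypothetical illustration only; not PyTorch Tabular's implementation.\n",
    "    def __init__(self, in_dim, hidden_dims, n_classes):\n",
    "        super().__init__()\n",
    "        # One small backbone per base model; each may use a different width.\n",
    "        self.backbones = nn.ModuleList(\n",
    "            [nn.Sequential(nn.Linear(in_dim, h), nn.ReLU()) for h in hidden_dims]\n",
    "        )\n",
    "        # The head sees the concatenated backbone outputs (an assumption).\n",
    "        self.head = nn.Linear(sum(hidden_dims), n_classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        feats = torch.cat([b(x) for b in self.backbones], dim=1)\n",
    "        return self.head(feats)\n",
    "\n",
    "\n",
    "torch.manual_seed(0)\n",
    "model = TinyStack(in_dim=7, hidden_dims=[8, 16, 4], n_classes=2)\n",
    "x = torch.randn(32, 7)           # random stand-in for 7 numeric features\n",
    "y = torch.randint(0, 2, (32,))   # random binary labels\n",
    "loss = nn.functional.cross_entropy(model(x), y)\n",
    "loss.backward()  # a single backward pass reaches every backbone and the head\n",
    "```\n",
    "\n",
    "Because one loss is backpropagated through the head into every backbone, all parts receive gradients in the same step, which is one way to read \"trained simultaneously\"."
   ]
  }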
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
